{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "593f7553-7038-498e-96d4-8255e5ce34f0",
   "metadata": {},
   "source": [
    "# Custom chain\n",
    "\n",
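    "To implement your own custom chain you can subclass `Chain` and implement the following methods:\n",
    "\n",
    "- `input_keys` (property): the keys the chain expects in its input dictionary.\n",
    "- `output_keys` (property): the keys the chain returns in its output dictionary.\n",
    "- `_call`: the chain's synchronous logic; it receives the inputs and an optional run manager for callbacks.\n",
    "- `_acall` (optional): the asynchronous counterpart of `_call`.\n",
    "- `_chain_type` (optional property): a short string identifying the type of chain.\n",
    "\n",
    "The example below mimics `LLMChain`: it formats a prompt from the inputs, sends it to a language model, and returns the model's text."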
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c19c736e-ca74-4726-bb77-0a849bcc2960",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "from typing import Any, Dict, List, Optional\n",
    "\n",
    "from pydantic import Extra\n",
    "\n",
    "from langchain.schema.language_model import BaseLanguageModel\n",
    "from langchain.callbacks.manager import (\n",
    "    AsyncCallbackManagerForChainRun,\n",
    "    CallbackManagerForChainRun,\n",
    ")\n",
    "from langchain.chains.base import Chain\n",
    "from langchain.prompts.base import BasePromptTemplate\n",
    "\n",
    "\n",
    "class MyCustomChain(Chain):\n",
    "    \"\"\"An example of a custom chain that formats a prompt and calls an LLM.\"\"\"\n",
    "\n",
    "    #: Prompt object to use.\n",
    "    prompt: BasePromptTemplate\n",
    "    #: Language model that receives the formatted prompt.\n",
    "    llm: BaseLanguageModel\n",
    "    output_key: str = \"text\"  #: :meta private:\n",
    "\n",
    "    class Config:\n",
    "        \"\"\"Configuration for this pydantic object.\"\"\"\n",
    "\n",
    "        arbitrary_types_allowed = True\n",
    "        extra = Extra.forbid\n",
    "\n",
    "    @property\n",
    "    def _chain_type(self) -> str:\n",
    "        \"\"\"Identifier string for this chain type.\"\"\"\n",
    "        return \"my_custom_chain\"\n",
    "\n",
    "    @property\n",
    "    def input_keys(self) -> List[str]:\n",
    "        \"\"\"The input variables declared by ``self.prompt``.\n",
    "\n",
    "        :meta private:\n",
    "        \"\"\"\n",
    "        return self.prompt.input_variables\n",
    "\n",
    "    @property\n",
    "    def output_keys(self) -> List[str]:\n",
    "        \"\"\"A one-element list containing ``self.output_key``.\n",
    "\n",
    "        :meta private:\n",
    "        \"\"\"\n",
    "        return [self.output_key]\n",
    "\n",
    "    def _call(\n",
    "        self,\n",
    "        inputs: Dict[str, Any],\n",
    "        run_manager: Optional[CallbackManagerForChainRun] = None,\n",
    "    ) -> Dict[str, str]:\n",
    "        \"\"\"Format the prompt, run the LLM, and return its first generation's text.\"\"\"\n",
    "        # Put your own chain logic here; this example mimics LLMChain.\n",
    "        formatted_prompt = self.prompt.format_prompt(**inputs)\n",
    "\n",
    "        # Any nested LLM or chain call should receive a child callback manager,\n",
    "        # obtained via `run_manager.get_child()`, so that callbacks registered\n",
    "        # on this outer run can also track the inner run.\n",
    "        child_callbacks = None\n",
    "        if run_manager:\n",
    "            child_callbacks = run_manager.get_child()\n",
    "        llm_result = self.llm.generate_prompt(\n",
    "            [formatted_prompt], callbacks=child_callbacks\n",
    "        )\n",
    "\n",
    "        # Calling methods on `run_manager` triggers any callbacks registered\n",
    "        # for that event -- e.g. to log something about this run.\n",
    "        if run_manager:\n",
    "            run_manager.on_text(\"Log something about this run\")\n",
    "\n",
    "        first_generation = llm_result.generations[0][0]\n",
    "        return {self.output_key: first_generation.text}\n",
    "\n",
    "    async def _acall(\n",
    "        self,\n",
    "        inputs: Dict[str, Any],\n",
    "        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,\n",
    "    ) -> Dict[str, str]:\n",
    "        \"\"\"Async twin of `_call`: same flow, but awaits the LLM and the log callback.\"\"\"\n",
    "        # Put your own chain logic here; this example mimics LLMChain.\n",
    "        formatted_prompt = self.prompt.format_prompt(**inputs)\n",
    "\n",
    "        # Pass a child callback manager to nested LLM/chain calls so callbacks\n",
    "        # registered on the outer run also see the inner run.\n",
    "        child_callbacks = None\n",
    "        if run_manager:\n",
    "            child_callbacks = run_manager.get_child()\n",
    "        llm_result = await self.llm.agenerate_prompt(\n",
    "            [formatted_prompt], callbacks=child_callbacks\n",
    "        )\n",
    "\n",
    "        # Async run managers expose awaitable event methods.\n",
    "        if run_manager:\n",
    "            await run_manager.on_text(\"Log something about this run\")\n",
    "\n",
    "        first_generation = llm_result.generations[0][0]\n",
    "        return {self.output_key: first_generation.text}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "18361f89",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new MyCustomChain chain...\u001b[0m\n",
      "Log something about this run\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'Why did the callback function feel lonely? Because it was always waiting for someone to call it back!'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain.callbacks.stdout import StdOutCallbackHandler\n",
    "from langchain.chat_models.openai import ChatOpenAI\n",
    "from langchain.prompts.prompt import PromptTemplate\n",
    "\n",
    "\n",
    "# The StdOutCallbackHandler prints when the chain starts and finishes, plus\n",
    "# anything logged through `run_manager.on_text` inside `_call`.\n",
    "chain = MyCustomChain(\n",
    "    llm=ChatOpenAI(),\n",
    "    prompt=PromptTemplate.from_template(\"tell us a joke about {topic}\"),\n",
    ")\n",
    "\n",
    "chain.run({\"topic\": \"callbacks\"}, callbacks=[StdOutCallbackHandler()])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
